How can I create a weighted loss function to handle class imbalance?

Hi everyone,

I’m working on fine-tuning LLaMA 3.1 8B for a multi-class classification task using text data. However, my dataset is highly imbalanced, which is causing performance issues.

I have already tried several data augmentation techniques like back translation, paraphrasing, and contextual augmentation, but they didn’t significantly improve the balance since they mostly generated near-duplicate copies of the original text.

To address this, I decided to implement a weighted loss function (e.g., focal loss or class-weighted cross-entropy) to better handle class imbalance. However, I’m facing challenges in integrating this into the Hugging Face Trainer class.

Has anyone successfully implemented custom loss functions in Hugging Face’s Trainer? If so, could you guide me on the best way to modify the training loop or integrate focal loss with Trainer? Any help would be greatly appreciated!

Thanks in advance!


There don’t seem to be many examples…

Hi there!

I’ve faced a similar challenge before, so I thought I’d share what worked for me when using a custom loss (like focal loss) with Hugging Face’s Trainer.

1. Create a Custom Trainer

The easiest way to use your own loss function is to subclass the Trainer and override the compute_loss method. In this method, you can compute your loss (for example, focal loss or class-weighted cross-entropy) instead of the default loss. Here’s a simple example:

from transformers import Trainer

class CustomLossTrainer(Trainer):
    def __init__(self, *args, loss_fn=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Store your custom loss function.
        # This should take (logits, labels) as arguments.
        self.loss_fn = loss_fn

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Recent versions of transformers pass extra keyword arguments
        # (e.g. num_items_in_batch), so accept **kwargs for compatibility.
        labels = inputs.get("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")

        # Replace the model's default loss with the custom loss function.
        loss = self.loss_fn(logits, labels)

        return (loss, outputs) if return_outputs else loss

2. Implement Your Custom Loss Function

You can define your custom loss function (for example, focal loss) as a separate function. For instance, here’s a simple version of focal loss using PyTorch:

import torch
import torch.nn.functional as F

def focal_loss(logits, labels, gamma=2.0, alpha=0.25):
    # Standard cross-entropy per example (no reduction yet).
    ce_loss = F.cross_entropy(logits, labels, reduction='none')

    # pt is the model's probability for the true class,
    # recovered from the cross-entropy loss.
    pt = torch.exp(-ce_loss)

    # Down-weight easy examples (high pt) and focus on hard ones.
    # Note: alpha here is a single scalar; for per-class weighting,
    # pass a weight tensor to F.cross_entropy instead.
    focal_loss = alpha * (1 - pt) ** gamma * ce_loss
    return focal_loss.mean()
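
Since you also mentioned class-weighted cross-entropy, here's a minimal sketch of that alternative, assuming you can supply the per-class example counts from your training set (the factory name make_weighted_ce is just illustrative):

import torch
import torch.nn.functional as F

def make_weighted_ce(class_counts):
    # Inverse-frequency weights: rarer classes get larger weights.
    counts = torch.tensor(class_counts, dtype=torch.float)
    weights = counts.sum() / (len(counts) * counts)

    def weighted_ce(logits, labels):
        # Move the weights to the logits' device/dtype and let PyTorch
        # rescale each class's contribution to the loss.
        return F.cross_entropy(
            logits, labels,
            weight=weights.to(device=logits.device, dtype=logits.dtype),
        )

    return weighted_ce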

3. Use the Custom Trainer

When you set up your training, pass your custom loss function to your trainer. For example:

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    evaluation_strategy="epoch",
    logging_steps=50,
    save_steps=500,
)

trainer = CustomLossTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss_fn=focal_loss,  # Pass your custom loss function here.
)
trainer.train()
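
If you'd rather try the class-weighted cross-entropy from step 2 instead of focal loss, you would just swap the loss_fn. As a rough sketch, assuming train_dataset is a datasets.Dataset with a "labels" column and num_labels is your number of classes (adapt both to your setup):

from collections import Counter

# Count how many training examples fall into each class.
label_counts = Counter(train_dataset["labels"])
class_counts = [label_counts[i] for i in range(num_labels)]

trainer = CustomLossTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss_fn=make_weighted_ce(class_counts),
)
trainer.train()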

I hope this helps you integrate a custom loss function with Hugging Face’s Trainer and improves your model’s performance on imbalanced data. If you have any more questions or need further clarifications, feel free to ask!

Good luck, and happy fine-tuning!

